import numpy as np
import itertools
import matplotlib.pyplot as plt
def value_iteration(p_s_given_sa, pi_a_given_s, r_sa, num_states, num_actions,
                    gamma, max_iters=1000, tol=1e-6):
    """Iteratively evaluate the Q-function of a fixed policy.

    Despite the name, no maximisation over actions happens here: each sweep
    applies the Bellman expectation backup for the supplied policy,

        V(s)    = sum_a  pi(a|s) * Q(s, a)
        Q(s, a) = r(s, a) + gamma * sum_s' p(s'|s, a) * V(s')

    starting from Q = 0.

    Parameters
    ----------
    p_s_given_sa : array
        Transition tensor indexed as [s, a, s'] (the last axis is contracted
        with the state-value vector).
    pi_a_given_s : array
        Policy indexed as [s, a]; multiplied element-wise with Q and summed
        over actions.
    r_sa : array
        Reward added to the discounted backup; expected to broadcast against
        a [num_states, num_actions] array.
    num_states, num_actions : int
        Dimensions of the Q table.
    gamma : float
        Discount factor.
    max_iters : int, optional
        Maximum number of backup sweeps (default 1000).
    tol : float, optional
        Stop once the infinity norm of the change in Q drops below this
        value (default 1e-6).

    Returns
    -------
    numpy.ndarray
        The Q estimate. When the tolerance check triggers, this is the
        iterate *before* the final (sub-tolerance) backup; otherwise it is
        the result of the last sweep.
    """
    Q = np.zeros([num_states, num_actions])
    for _ in range(max_iters):
        # Expected state value under the policy.
        V = np.einsum('ij,ij->i', pi_a_given_s, Q)
        # Expected next-state value for every (state, action) pair.
        V_next = np.einsum('ijk,k->ij', p_s_given_sa, V)
        Q_new = r_sa + gamma * V_next
        # NOTE: for a 2-D array, ord=np.inf is the matrix infinity norm
        # (maximum absolute row sum), not the largest element-wise change.
        if np.linalg.norm(Q - Q_new, ord=np.inf) < tol:
            break
        # Q_new is a freshly allocated array each sweep, so no copy is needed.
        Q = Q_new
    return Q
def compute_initial_state_value(num_states, num_actions, Q, dynamics, reward,
                                gamma=0.9):
    """Return the value at state 0 of the policy that is greedy w.r.t. ``Q``.

    A deterministic policy is built by taking ``argmax`` over the action axis
    of ``Q`` (ties resolve to the lowest action index). That policy is then
    evaluated with ``value_iteration`` on the supplied ``dynamics`` and
    ``reward``.

    Parameters
    ----------
    num_states, num_actions : int
        Dimensions of the state-action table.
    Q : array
        Table indexed as [s, a] from which the greedy policy is derived.
    dynamics : array
        Transition tensor passed through to ``value_iteration``.
    reward : array
        Reward table passed through to ``value_iteration``.
    gamma : float, optional
        Discount factor used for the evaluation (default 0.9, the value that
        was previously hard-coded).

    Returns
    -------
    float
        Evaluated value of the greedy policy at state index 0.
    """
    # One-hot encoding of the greedy action in every state.
    greedy_actions = np.argmax(Q, axis=1)
    pi_a_given_s_eval = np.zeros([num_states, num_actions])
    pi_a_given_s_eval[np.arange(num_states), greedy_actions] = 1

    # Value of the greedy policy on the given dynamics and reward.
    Q_exact = value_iteration(dynamics, pi_a_given_s_eval, reward,
                              num_states, num_actions, gamma)
    V_exact = np.einsum('ij,ij->i', pi_a_given_s_eval, Q_exact)
    return V_exact[0]

def compute_optimal(num_states, num_actions, dynamics, reward, gamma=0.9,
                    max_iters=100):
    """Run policy iteration and return the resulting value at state 0.

    Starting from the policy that picks action 0 everywhere (the argmax of an
    all-zero Q table), the loop alternates between evaluating the current
    policy with ``value_iteration`` and replacing it with the policy that is
    greedy w.r.t. the evaluated Q.

    Parameters
    ----------
    num_states, num_actions : int
        Dimensions of the state-action table.
    dynamics : array
        Transition tensor passed through to ``value_iteration``.
    reward : array
        Reward table passed through to ``value_iteration``.
    gamma : float, optional
        Discount factor (default 0.9, the value that was previously
        hard-coded).
    max_iters : int, optional
        Maximum number of evaluate/improve rounds (default 100, the value
        that was previously hard-coded).

    Returns
    -------
    float
        ``sum_a pi(a|0) * Q(0, a)`` for the final greedy policy ``pi`` and the
        last evaluated ``Q``.
    """
    state_idx = np.arange(num_states)

    Q = np.zeros([num_states, num_actions])
    greedy_actions = np.argmax(Q, axis=1)
    pi_a_given_s_eval = np.zeros([num_states, num_actions])
    pi_a_given_s_eval[state_idx, greedy_actions] = 1

    for _ in range(max_iters):
        # Policy evaluation.
        Q = value_iteration(dynamics, pi_a_given_s_eval, reward,
                            num_states, num_actions, gamma)
        # Policy improvement.
        improved_actions = np.argmax(Q, axis=1)
        if np.array_equal(improved_actions, greedy_actions):
            # The greedy policy is unchanged. Evaluation is deterministic and
            # always restarts from zeros, so every further round would
            # reproduce exactly the same Q and policy; stop early.
            break
        greedy_actions = improved_actions
        pi_a_given_s_eval = np.zeros([num_states, num_actions])
        pi_a_given_s_eval[state_idx, greedy_actions] = 1

    V_exact = np.einsum('ij,ij->i', pi_a_given_s_eval, Q)
    return V_exact[0]


